/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file dropout.cc
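 * \brief GPU registrations for the Dropout operator and its backward
 *        (FStatefulCompute<gpu> and CUDA-graph compatibility).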
 * \author Bing Xu, Da Zheng, Hang Zhang
 */

#include "./dropout-inl.h"

namespace mxnet {
namespace op {

NNVM_REGISTER_OP(Dropout)
    .set_attr<FIsCUDAGraphsCompatible>(
        "FIsCUDAGraphsCompatible",
        // Reports whether this Dropout node may be captured into a CUDA graph.
        // The answer follows which forward implementation will actually run.
        // A passthrough (no masking) is capturable for every implementation.
        // While masking is active, only the cuDNN implementation is treated
        // as capturable.
        // NOTE(review): intended to mirror the dispatch in the Dropout forward
        // implementation (dropout-inl.h); keep the two in sync.
        [](const NodeAttrs& attrs, const bool is_train) {
          const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
          const real_t pkeep = 1.0f - param.p;
          // Masking happens when something is actually dropped (pkeep < 1),
          // either during training or, with mode='always', outside training too.
          const bool masking_active =
              pkeep < 1 && (is_train || param.mode == dropout::kAlways);
          // Without masking, Dropout is a passthrough for all implementations.
          if (!masking_active)
            return true;
#if MXNET_USE_CUDNN_DROPOUT
          // Only the cuDNN implementation is capturable while masking.
          // cuDNN is used only if all of the following hold:
          //   - it has not been disabled via cudnn_off;
          //   - at least one element is kept (pkeep > 0);
          //   - the mask is element-wise (no shared-mask `axes`).
          const bool cudnn_off = param.cudnn_off && param.cudnn_off.value();
          const bool cudnn_used =
              pkeep > 0 && !cudnn_off && param.axes.ndim() == 0;
          return cudnn_used;
#else
          // Without cuDNN dropout the generic masking path runs, which is not
          // treated as graph-compatible.
          return false;
#endif  // MXNET_USE_CUDNN_DROPOUT
        })
    .set_attr<FStatefulCompute>("FStatefulCompute<gpu>", DropoutCompute<gpu>);

// GPU backward of Dropout: registers DropoutGradCompute<gpu> as the stateful
// compute function for the _backward_Dropout op.
NNVM_REGISTER_OP(_backward_Dropout)
    .set_attr<FStatefulCompute>(
        "FStatefulCompute<gpu>",
        DropoutGradCompute<gpu>);

}  // namespace op
}  // namespace mxnet
